{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7acdd894-f602-4732-8310-e5c3b80ae35b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1102, -0.0530, -0.0823, -0.0572,  0.0371, -0.1223, -0.2750,  0.2275,\n",
       "          0.0352, -0.1847],\n",
       "        [-0.1268, -0.0228, -0.0331, -0.0430, -0.0009, -0.0499, -0.1481,  0.1189,\n",
       "          0.1074, -0.0780]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))\n",
    "\n",
    "X = torch.rand(2, 20)\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "02f82147-5935-4e97-9364-da1c67e7e90a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    # 用模型参数声明层。这里，我们声明两个全连接的层\n",
    "    def __init__(self):\n",
    "        # 调用MLP的父类Module的构造函数来执行必要的初始化。\n",
    "        # 这样，在类实例化时也可以指定其他函数参数，例如模型参数params（稍后将介绍）\n",
    "        super().__init__()\n",
    "        self.hidden = nn.Linear(20, 256)  # 隐藏层\n",
    "        self.out = nn.Linear(256, 10)  # 输出层\n",
    "\n",
    "    # 定义模型的前向传播，即如何根据输入X返回所需的模型输出\n",
    "    def forward(self, X):\n",
    "        # 注意，这里我们使用ReLU的函数版本，其在nn.functional模块中定义。\n",
    "        return self.out(F.relu(self.hidden(X)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2c725a39-1b44-4774-b17b-848431073a02",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0155, -0.1538,  0.0419,  0.0646,  0.0941, -0.0906, -0.1583, -0.2175,\n",
       "          0.1967,  0.1549],\n",
       "        [ 0.0398, -0.0921, -0.0216,  0.1625,  0.1485, -0.1613, -0.0407, -0.0838,\n",
       "          0.2085,  0.1822]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = MLP()\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6c4d53ed-d79a-4c88-b525-130c7cea3e2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MySequential(nn.Module):\n",
    "    def __init__(self, *args):\n",
    "        super().__init__()\n",
    "        for idx, module in enumerate(args):\n",
    "            # 这里，module是Module子类的一个实例。我们把它保存在'Module'类的成员\n",
    "            # 变量_modules中。_module的类型是OrderedDict\n",
    "            self._modules[str(idx)] = module\n",
    "\n",
    "    def forward(self, X):\n",
    "        # OrderedDict保证了按照成员添加的顺序遍历它们\n",
    "        for block in self._modules.values():\n",
    "            X = block(X)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0ffc382b-a147-4569-8629-eaf9dc2fb5cf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.3062,  0.2096,  0.0659,  0.0231, -0.0035, -0.0812, -0.3420, -0.0507,\n",
       "          0.0050, -0.2361],\n",
       "        [-0.1956,  0.1056,  0.0174,  0.0352, -0.1179,  0.0263, -0.2776, -0.1839,\n",
       "          0.0592, -0.2271]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = MySequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "41f2c95b-543b-4a70-bbd8-f01c11f1ee15",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.1654, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class FixedHiddenMLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # 不计算梯度的随机权重参数。因此其在训练期间保持不变\n",
    "        self.rand_weight = torch.rand((20, 20), requires_grad=False)\n",
    "        self.linear = nn.Linear(20, 20)\n",
    "\n",
    "    def forward(self, X):\n",
    "        X = self.linear(X)\n",
    "        # 使用创建的常量参数以及relu和mm函数\n",
    "        X = F.relu(torch.mm(X, self.rand_weight) + 1)\n",
    "        # 复用全连接层。这相当于两个全连接层共享参数\n",
    "        X = self.linear(X)\n",
    "        # 控制流\n",
    "        while X.abs().sum() > 1:\n",
    "            X /= 2\n",
    "        return X.sum()\n",
    "net = FixedHiddenMLP()\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e4f1cd93-ce2f-48e7-b727-c3132f1f2346",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-0.1362, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class NestMLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),\n",
    "                                 nn.Linear(64, 32), nn.ReLU())\n",
    "        self.linear = nn.Linear(32, 16)\n",
    "\n",
    "    def forward(self, X):\n",
    "        return self.linear(self.net(X))\n",
    "\n",
    "chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())\n",
    "chimera(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b34b776-97c0-494a-9437-3c51bd33a519",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
